import torch
from torch import nn, Tensor
import os
import re
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
import yaml
import multiprocessing as mp
import json
import argparse


import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from scipy.stats import wasserstein_distance
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from src.sumformer import *
from eval_utils import *



def extract_hyperparameters(filepath):
    """Parse ``<name><integer>`` tokens out of a path string.

    Each run of lowercase letters and hyphens that is immediately followed by
    digits counts as one hyperparameter occurrence. Hyphens stay part of the
    name, so ``"depth-4-ed-64"`` yields the keys ``"depth-"`` and ``"-ed-"``.

    Args:
        filepath: String to scan (typically a checkpoint path).

    Returns:
        A plain dict mapping each name to the list of integer values found
        for it, in order of appearance.
    """
    parsed = {}
    # findall yields (name, value) tuples because the pattern has two groups.
    for name, digits in re.findall(r"(?P<name>[a-z-]+)(?P<value>\d+)", filepath):
        parsed.setdefault(name, []).append(int(digits))
    return parsed

def eval_ellipse_tf(fp, raw_data, params=None, device='cpu', in_dim=3):
    """Evaluate a trained ``ConvexHullNNTransformer`` on a collection of point sets.

    Args:
        fp: Path to a checkpoint file (any path containing ``'.pt'``) or to a
            directory holding ``final_model.pt``. When ``params`` is None, the
            architecture hyperparameters are also parsed from this string.
        raw_data: Indexable collection of 2-D arrays. The last column of each
            is dropped to form the point sets passed to ``avg_err``
            (presumably coordinates plus one extra column — TODO confirm).
        params: Keyword arguments for ``ConvexHullNNTransformer``. If None,
            they are recovered from ``fp`` via ``extract_hyperparameters``.
        device: Torch device that the model and batches are moved to.
        in_dim: Point dimensionality. It is used as ``input_dim`` when the
            model is built from ``fp`` and is forwarded to ``avg_err``.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        The ``(dir_width, std)`` pair returned by ``avg_err``.
    """
    if params is None:
        # Rebuild the architecture from the path. Key spellings (hyphens
        # included) follow the regex in extract_hyperparameters.
        hparams = extract_hyperparameters(fp)
        depth = hparams['depth-'][0]
        ed = hparams['-ed-'][0]
        hd = hparams['-hd-'][0]
        od = hparams['-od'][0] if in_dim == 2 else hparams['-od-'][0]
        tod = hparams['-tod-'][0]
        heads = hparams['-heads-'][0]
        model = ConvexHullNNTransformer(depth=depth, num_heads=heads, embedding_dim=ed,
                                        hidden_dim=hd, input_dim=in_dim,
                                        transformer_output_dim=tod, output_dim=od)
    else:
        model = ConvexHullNNTransformer(**params)

    # Drop the last column of every sample to get the raw point sets.
    ptsets = [sample[:, :-1] for sample in raw_data]
    # The second return value (ground-truth hulls) is unused while the
    # Wasserstein metric is disabled.
    dataloader, _gt = npz_to_batches(raw_data, 128)

    state_dict_path = fp if '.pt' in fp else os.path.join(fp, 'final_model.pt')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state_dict = torch.load(state_dict_path, map_location=device)
    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)
    # Inference: put train-mode-dependent modules (e.g. dropout) in eval mode.
    model.eval()

    chull = []
    with torch.no_grad():  # no autograd graphs are needed for evaluation
        for batch in dataloader:
            # TODO: assumes a constant point-set size throughout the batch.
            n = batch.n_nodes[0].item()
            batch = batch.to(device)
            out = model(batch)
            # Softmax over the n points of each set, then flatten back.
            out = out.view(-1, n, out.size(-1))
            out = F.softmax(out, dim=1)
            out = out.view(-1, out.size(-1))
            chull += [t.cpu().detach().numpy() for t in get_approx_chull(out, batch)]

    dir_width, std = avg_err(chull, ptsets, in_dim=in_dim)
    return dir_width, std





def get_models(yml_file_path):
    """Load a YAML model-spec file and return its top-level entries.

    The file is expected to contain a top-level mapping of
    ``model_name -> config``.

    Args:
        yml_file_path: Path to the YAML spec file.

    Returns:
        A list of ``(model_name, config)`` tuples in file order. An empty
        file yields an empty list.
    """
    with open(yml_file_path, 'r') as file:
        data = yaml.safe_load(file)

    # safe_load returns None for an empty document. Treat that as "no models"
    # instead of failing with AttributeError on None.items().
    if data is None:
        return []
    return list(data.items())


def main():
    """Evaluate every model in a YAML spec file and write a results CSV.

    Command-line arguments:
        --experiment   Per-model sub-directory that holds ``final_model.pt``;
                       also the prefix of the output CSV name.
        --model_specs  YAML file read by ``get_models``. Each
                       ``(name, params)`` entry is evaluated.
        --train        Name of the dataset the models were trained on. It
                       selects the in-distribution test set and the
                       synthetic mix used as the OOD set.

    Each checkpoint is scored with ``eval_ellipse_tf`` on three sets: the
    in-distribution set, the OOD synthetic mix and the modelnet test set.
    The ``(directional error, std)`` pairs are written to
    ``/data/oren/coreset/out/<experiment>_abbrev_chull_results.csv``.
    """
    import traceback  # only needed to report a failed results write

    # Test sets by name. The relative paths resolve against the current
    # working directory. NOTE(review): this is fragile; they presumably point
    # at the same /data/oren/coreset/data directory as the absolute paths.
    dataset_paths = {
        "thin_ellipses": '../../../../../data/oren/coreset/data/3d_ellipse_500_test.npy',
        "circles": '../../../../../data/oren/coreset/data/3d_uniform_500_test.npy',
        "gauss": '../../../../../data/oren/coreset/data/single_gauss_3d_test.npy',
        "mixed_gauss": '../../../../../data/oren/coreset/data/mix_gauss_3d_test.npy',
        "manifold": '../../../../../data/oren/coreset/data/3d_manifold_ellipse_test.npy',
        "modelnet": '../../../../../data/oren/coreset/data/scaled_subsampled_modelnet_coreset_test.npy',
        "mixed": '/data/oren/coreset/data/mixed_3d_500_test.npy',
    }

    parser = argparse.ArgumentParser(
        description='Evaluate ConvexHullNNTransformer checkpoints on 3D test sets.')
    parser.add_argument('--experiment', type=str, required=True,
                        help='experiment sub-directory inside each model directory')
    parser.add_argument('--model_specs', type=str, required=True,
                        help='YAML file mapping model names to constructor params')
    parser.add_argument('--train', type=str, required=True,
                        choices=sorted(dataset_paths),
                        help='dataset the models were trained on')
    args = parser.parse_args()

    experiment = args.experiment
    train = args.train
    filepaths = get_models(args.model_specs)

    # Load only the test sets this run actually uses.
    in_dist = np.load(dataset_paths[train])
    modelnet = in_dist if train == "modelnet" else np.load(dataset_paths["modelnet"])
    if train in ("modelnet", "mixed"):
        # These training sets are scored OOD against the full synthetic mix.
        ood_mixed = np.load(dataset_paths["mixed"])
    else:
        # Synthetic mix with the training distribution left out.
        ood_mixed = np.load(f'/data/oren/coreset/data/mixed_3d_no_{train}.npy')

    models_root = '/data/oren/coreset/models/elliptical-50/ConvexHullNNTransformer/direction/'
    errs = {}
    ood_errs = {}
    modelnet_errs = {}
    for fp, params in tqdm(filepaths):
        print(f'Evaluating {fp}')
        read = os.path.join(models_root, fp, experiment, 'final_model.pt')
        errs[fp] = eval_ellipse_tf(read, in_dist, params)
        ood_errs[fp] = eval_ellipse_tf(read, ood_mixed, params)
        modelnet_errs[fp] = eval_ellipse_tf(read, modelnet, params)

    out_path = f'/data/oren/coreset/out/{experiment}_abbrev_chull_results.csv'
    try:
        data = {
            'model': list(errs.keys()),
            'experiment': experiment,
            f'{train} directional error': [val[0] for val in errs.values()],
            f'{train} std': [val[1] for val in errs.values()],
            'ood synthetic directional error': [val[0] for val in ood_errs.values()],
            'ood synthetic std': [val[1] for val in ood_errs.values()],
            'modelnet directional error': [val[0] for val in modelnet_errs.values()],
            'modelnet std': [val[1] for val in modelnet_errs.values()],
        }
        print(data)
        pd.DataFrame(data).to_csv(out_path, index=False)
    except Exception:
        # Best effort: report the failure rather than silently swallowing it.
        # A bare except would also have caught KeyboardInterrupt/SystemExit.
        print('Exception')
        traceback.print_exc()

if __name__ == '__main__':  # run the evaluation only when executed as a script
    main()